#!/usr/bin/env python

import argparse
from collections import OrderedDict
import hashlib
import os
import sys

import numpy as np
import sm
import yaml


def parse_args():
    """Define the command line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``cam_yaml``, ``imu_yaml``,
        ``ncamera_label``, ``output_file_name``, ``imu_output_file_name`` and
        ``to_ncamera``.
    """
    # One (flag, add_argument keyword arguments) pair per option, listed in
    # the order in which they are registered with the parser.
    option_table = (
        ('--cam', {
            'dest': 'cam_yaml',
            'help': "Camera configuration as yaml file. This can either be in maplab "
                    "format (sensors.yaml/ncamera.yaml) or kalibr format "
                    "(camchain-imucam.yaml).",
            'required': True,
        }),
        ('--imu', {
            'dest': 'imu_yaml',
            'help': "IMU configuration (only used for conversion from kalibr format "
                    "to the sensors.yaml maplab format).",
            'required': False,
        }),
        ('--label', {
            'dest': 'ncamera_label',
            'help': "Name of the NCamera label (only used for conversion from kalibr "
                    "format).",
            'default': 'ncamera',
            'required': False,
        }),
        ('--out', {
            'dest': 'output_file_name',
            'help': "Name of the output file (optional if input file is named "
                    "camchain-imucam-* or sensors-*).",
        }),
        ('--imu-out', {
            'dest': 'imu_output_file_name',
            'help': "Name of the IMU output file (only for conversions from "
                    "maplab to kalibr.",
        }),
        ('--to-ncamera', {
            'dest': 'to_ncamera',
            'help': "If set to true, the kalibr calibration file is converted to the "
                    "ncamera maplab format.",
            'default': False,
            'required': False,
            'action': 'store_true',
        }),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def list_to_ordered_dict(input_list):
    """Convert a (nested) list into a rows/cols/data dictionary.

    A one-dimensional input of length n is reported as n rows and 1 column.
    The 'data' entry holds all elements as a flat list in row-major order.

    Args:
        input_list: List or list of lists (at most two dimensions).

    Returns:
        OrderedDict with the keys 'rows', 'cols' and 'data', in that order.
    """
    matrix = np.array(input_list)
    assert matrix.ndim < 3
    if matrix.ndim == 1:
        # Interpret a plain vector as a single column.
        matrix = matrix.reshape((matrix.shape[0], 1))
    result = OrderedDict()
    result['rows'] = matrix.shape[0]
    result['cols'] = matrix.shape[1]
    result['data'] = matrix.ravel().tolist()
    return result


def dict_to_list(input_dict):
    """Convert a rows/cols/data dictionary back into a (nested) list.

    This is the inverse of list_to_ordered_dict(): 'data' is interpreted as a
    flat row-major list of a matrix with 'rows' rows and 'cols' columns.
    Dimensions of size one are squeezed away, so an n x 1 or 1 x n matrix is
    returned as a flat list of n elements.

    Args:
        input_dict: Mapping with the keys 'rows', 'cols' and 'data'.

    Returns:
        Nested list (matrix), flat list (vector) or scalar (1 x 1 input).
    """
    # Reshape to (rows, cols); using (cols, rows) would scramble every
    # non-square matrix because the data is stored row by row.
    shape = (input_dict['rows'], input_dict['cols'])
    matrix = np.array(input_dict['data']).reshape(shape)
    return np.squeeze(matrix).tolist()


class ImuConfig:
    """IMU settings that can be read from and written to kalibr and maplab
    style dictionaries.

    A new instance is empty. One of the parse_from_* methods has to be called
    before get_kalibr_format() or get_maplab_format() may be used.
    """

    def __init__(self):
        self.__is_initialized = False

    def parse_from_kalibr_format(self, imu_dict):
        """Fill this configuration from a kalibr IMU dictionary.

        Only the 'imu0' entry is read. The id is the md5 hex digest of the
        rostopic. Saturation limits and gravity magnitude are not read from
        the dictionary and are set to fixed values.

        Args:
            imu_dict: Mapping that contains an 'imu0' entry.
        """
        assert 'imu0' in imu_dict, "No 'imu0' entry found in the IMU yaml file."
        if len(imu_dict) > 1:
            print(
                "Got more than one IMU. Taking configuration of 'imu0' entry.")
        kalibr_imu = imu_dict['imu0']

        topic = kalibr_imu['rostopic']
        self.rostopic = topic
        self.id = hashlib.md5(topic.encode('utf-8')).hexdigest()
        self.T_i_b = np.array(kalibr_imu['T_i_b'])

        # (attribute name, kalibr key) pairs that are copied verbatim.
        copied_values = (
            ('time_offset', 'time_offset'),
            ('update_rate', 'update_rate'),
            # Gyro noise density. [rad/s*1/sqrt(Hz)]
            ('gyro_noise_density', 'gyroscope_noise_density'),
            # Accelerometer noise density. [m/s^2*1/sqrt(Hz)]
            ('acc_noise_density', 'accelerometer_noise_density'),
            # Gyro bias random walk. [rad/s^2*1/sqrt(Hz)]
            ('gyro_random_walk', 'gyroscope_random_walk'),
            # Accelerometer bias random walk. [m/s^3*1/sqrt(Hz)]
            ('acc_random_walk', 'accelerometer_random_walk'),
        )
        for attribute, kalibr_key in copied_values:
            setattr(self, attribute, kalibr_imu[kalibr_key])

        # Values without a counterpart in the kalibr dictionary.
        self.saturation_accel_max_mps2 = 150.0
        self.saturation_gyro_max_radps = 7.5
        self.gravity_magnitude_mps2 = 9.808083883386614
        self.__is_initialized = True

    def parse_from_maplab_format(self, imu_dict):
        """Fill this configuration from a maplab IMU sensor dictionary.

        The transformation, time offset and update rate are not read from the
        dictionary and are set to identity, 0.0 and 200.0 respectively.

        Args:
            imu_dict: Mapping with 'hardware_id', 'id', 'sigmas' and the
                saturation/gravity entries.
        """
        self.rostopic = imu_dict['hardware_id']
        self.id = imu_dict['id']
        # Values without a counterpart in the maplab dictionary.
        self.T_i_b = np.eye(4)
        self.time_offset = 0.0
        self.update_rate = 200.0

        sigmas = imu_dict['sigmas']
        # Gyro noise density. [rad/s*1/sqrt(Hz)]
        self.gyro_noise_density = sigmas['gyro_noise_density']
        # Accelerometer noise density. [m/s^2*1/sqrt(Hz)]
        self.acc_noise_density = sigmas['acc_noise_density']
        # Gyro bias random walk. [rad/s^2*1/sqrt(Hz)]
        self.gyro_random_walk = sigmas['gyro_bias_random_walk_noise_density']
        # Accelerometer bias random walk. [m/s^3*1/sqrt(Hz)]
        self.acc_random_walk = sigmas['acc_bias_random_walk_noise_density']

        for key in ('saturation_accel_max_mps2', 'saturation_gyro_max_radps',
                    'gravity_magnitude_mps2'):
            setattr(self, key, imu_dict[key])
        self.__is_initialized = True

    def get_kalibr_format(self):
        """Return the configuration as a kalibr style OrderedDict.

        The result has a single 'imu0' entry. The model is always reported as
        'calibrated'.
        """
        assert self.__is_initialized
        imu0_entries = [
            ('T_i_b', self.T_i_b.tolist()),
            ('accelerometer_noise_density', self.acc_noise_density),
            ('accelerometer_random_walk', self.acc_random_walk),
            ('gyroscope_noise_density', self.gyro_noise_density),
            ('gyroscope_random_walk', self.gyro_random_walk),
            ('model', 'calibrated'),
            ('rostopic', self.rostopic),
            ('time_offset', self.time_offset),
            ('update_rate', self.update_rate),
        ]
        return OrderedDict([('imu0', OrderedDict(imu0_entries))])

    def get_maplab_format(self):
        """Return the configuration as a maplab style OrderedDict.

        The rostopic is written as 'hardware_id' and the noise parameters are
        grouped under 'sigmas'.
        """
        assert self.__is_initialized
        sigmas = OrderedDict()
        sigmas['gyro_noise_density'] = self.gyro_noise_density
        sigmas['gyro_bias_random_walk_noise_density'] = self.gyro_random_walk
        sigmas['acc_noise_density'] = self.acc_noise_density
        sigmas['acc_bias_random_walk_noise_density'] = self.acc_random_walk
        return OrderedDict([
            ('sensor_type', 'IMU'),
            ('hardware_id', self.rostopic),
            ('id', self.id),
            ('sigmas', sigmas),
            ('saturation_accel_max_mps2', self.saturation_accel_max_mps2),
            ('saturation_gyro_max_radps', self.saturation_gyro_max_radps),
            ('gravity_magnitude_mps2', self.gravity_magnitude_mps2),
        ])


class CameraConfig:
    """Calibration of a single camera that can be read from and written to
    kalibr and maplab style dictionaries.

    The transformation is stored as T_cn_i, i.e. the value of kalibr's
    'T_cam_imu' entry or the inverse of maplab's 'T_B_C' entry.
    """

    def __init__(self, name):
        self.__is_initialized = False
        self.name = name
        # Default id derived from the camera name; parse_from_maplab_format()
        # overwrites it with the id found in the dictionary.
        self.id = hashlib.md5(name.encode('utf-8')).hexdigest()
        self.line_delay_nanoseconds = 0
        self.cam_overlaps = list()
        self.timeshift_cam_imu = 0.0

    def parse_from_kalibr_format(self, cam_dict):
        """Fill this configuration from one camera entry of a kalibr camchain.

        The '/image_raw' part is removed from the rostopic. The optional
        'timeshift_cam_imu' entry is only taken over when present.
        """
        required_keys = ('T_cam_imu', 'cam_overlaps', 'camera_model',
                         'distortion_coeffs', 'distortion_model', 'intrinsics',
                         'resolution', 'rostopic')
        missing_keys = [key for key in required_keys if key not in cam_dict]
        assert not missing_keys

        self.T_cn_i = sm.Transformation(np.array(cam_dict['T_cam_imu']))
        # These entries are stored under the same name as in the dictionary.
        for key in ('cam_overlaps', 'camera_model', 'distortion_coeffs',
                    'distortion_model', 'intrinsics', 'resolution'):
            setattr(self, key, cam_dict[key])
        self.rostopic = cam_dict['rostopic'].replace('/image_raw', '')
        if 'timeshift_cam_imu' in cam_dict:
            self.timeshift_cam_imu = cam_dict['timeshift_cam_imu']
        self.__is_initialized = True

    def parse_from_maplab_format(self, cam_dict):
        """Fill this configuration from one camera entry of a maplab ncamera.

        The distortion type 'radial-tangential' is stored under its kalibr
        name 'radtan'; all other types are kept unchanged.
        """
        assert not [key for key in ('camera', 'T_B_C') if key not in cam_dict]
        camera = cam_dict['camera']
        camera_keys = ('label', 'id', 'line-delay-nanoseconds', 'image_height',
                       'image_width', 'type', 'intrinsics', 'distortion')
        assert not [key for key in camera_keys if key not in camera]

        T_B_C = np.array(dict_to_list(cam_dict['T_B_C']))
        self.T_cn_i = sm.Transformation(T_B_C).inverse()
        self.rostopic = camera['label']
        self.id = camera['id']
        self.line_delay_nanoseconds = camera['line-delay-nanoseconds']
        # Stored as [width, height].
        self.resolution = [camera['image_width'], camera['image_height']]
        self.camera_model = camera['type']
        self.intrinsics = dict_to_list(camera['intrinsics'])
        distortion = camera['distortion']
        distortion_type = distortion['type']
        if distortion_type == 'radial-tangential':
            distortion_type = 'radtan'
        self.distortion_model = distortion_type
        self.distortion_coeffs = dict_to_list(distortion['parameters'])
        self.__is_initialized = True

    def get_kalibr_format(self, T_cnm1_i=None):
        """Return the configuration as a kalibr camchain entry.

        Args:
            T_cnm1_i: Optional T_cn_i of the previous camera in the chain. If
                given, the relative transformation 'T_cn_cnm1' is added.

        Returns:
            OrderedDict with the kalibr camera entries. The rostopic always
            ends with '/image_raw'.
        """
        assert self.__is_initialized
        entries = [
            ('camera_model', self.camera_model),
            ('intrinsics', self.intrinsics),
            ('distortion_model', self.distortion_model),
            ('distortion_coeffs', self.distortion_coeffs),
        ]
        if T_cnm1_i is not None:
            T_cn_cnm1 = self.T_cn_i * T_cnm1_i.inverse()
            entries.append(('T_cn_cnm1', T_cn_cnm1.T().tolist()))
        entries.append(('T_cam_imu', self.T_cn_i.T().tolist()))
        entries.append(('cam_overlaps', self.cam_overlaps))
        entries.append(('timeshift_cam_imu', self.timeshift_cam_imu))
        image_topic = self.rostopic
        if not image_topic.endswith('/image_raw'):
            image_topic = image_topic + '/image_raw'
        entries.append(('rostopic', image_topic))
        entries.append(('resolution', self.resolution))
        return OrderedDict(entries)

    def get_maplab_format(self):
        """Return the configuration as a maplab ncamera camera entry.

        The distortion model 'radtan' is written under its maplab name
        'radial-tangential' and 'T_B_C' is the inverse of T_cn_i.
        """
        assert self.__is_initialized
        camera = OrderedDict()
        camera['label'] = self.rostopic
        camera['id'] = self.id
        camera['line-delay-nanoseconds'] = self.line_delay_nanoseconds
        # self.resolution is [width, height].
        camera['image_height'] = self.resolution[1]
        camera['image_width'] = self.resolution[0]
        camera['type'] = self.camera_model
        camera['intrinsics'] = list_to_ordered_dict(self.intrinsics)
        if self.distortion_model == 'radtan':
            distortion_type = 'radial-tangential'
        else:
            distortion_type = self.distortion_model
        distortion = OrderedDict()
        distortion['type'] = distortion_type
        distortion['parameters'] = list_to_ordered_dict(self.distortion_coeffs)
        camera['distortion'] = distortion

        T_B_C = self.T_cn_i.inverse().T().tolist()
        return OrderedDict([('camera', camera),
                            ('T_B_C', list_to_ordered_dict(T_B_C))])


def setup_yaml():
    ''' Registers a yaml representer that dumps an OrderedDict() as a regular
    yaml mapping built from its items().'''

    def represent_ordered_dict(dumper, data):
        return dumper.represent_mapping('tag:yaml.org,2002:map', data.items())

    yaml.add_representer(OrderedDict, represent_ordered_dict)


class IndentDumper(yaml.Dumper):
    """yaml.Dumper that never passes indentless=True on to the base class."""

    def increase_indent(self, flow=False, indentless=False):
        # The caller's indentless value is deliberately ignored; False is
        # always forwarded instead.
        parent = super(IndentDumper, self)
        return parent.increase_indent(flow=flow, indentless=False)


def read_yaml(file_name):
    """Load a yaml file and return its parsed content.

    Args:
        file_name: Path of the yaml file to read.

    Returns:
        The parsed yaml document, or None if the file could not be parsed
        (the parser error is printed in that case).

    Raises:
        IOError/OSError: If the file cannot be opened.
    """
    # The context manager closes the file even if parsing fails.
    with open(file_name, 'r') as yaml_file:
        try:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects.
            # Only use this on trusted calibration files, or consider
            # yaml.SafeLoader if the inputs may come from an untrusted source.
            return yaml.load(yaml_file, yaml.Loader)
        except yaml.YAMLError as exception:
            print(exception)
    return None

def write_yaml(file_name,
               output_dict,
               default_flow_style=None,
               use_indent_dumper=False):
    """Dump a dictionary to a yaml file.

    Args:
        file_name: Path of the file to write (an existing file is
            overwritten).
        output_dict: Data to dump.
        default_flow_style: Forwarded unchanged to yaml.dump().
        use_indent_dumper: If true, IndentDumper is used instead of the
            default dumper.

    A yaml.YAMLError raised while dumping is printed and not re-raised.

    Raises:
        IOError/OSError: If the file cannot be opened for writing.
    """
    dump_options = {
        'default_flow_style': default_flow_style,
        # Very large line width so that the emitter does not wrap long lines.
        'width': 2147483647,
    }
    if use_indent_dumper:
        dump_options['Dumper'] = IndentDumper
    # The context manager flushes and closes the file in every case.
    with open(file_name, 'w') as output_file:
        try:
            yaml.dump(output_dict, output_file, **dump_options)
        except yaml.YAMLError as exception:
            print(exception)


class CalibrationConfig:
    """Calibration of a camera rig and, optionally, an IMU.

    The calibration is loaded from a kalibr or maplab yaml file on
    construction; write_converted_yaml() writes it in the respective other
    format.
    """

    def __init__(self, parsed_args):
        """Load the calibration described by the command line arguments.

        Args:
            parsed_args: Namespace as returned by parse_args().

        Exits the program if the camera file has an unknown format, or if a
        kalibr file is to be converted to the sensors format and no readable
        IMU yaml file was given.
        """
        self.imu = ImuConfig()
        self.__has_imu = False
        self.ncamera_label = parsed_args.ncamera_label
        # Default id derived from the label; replaced by the id stored in the
        # file when the input is in maplab format.
        self.id = hashlib.md5(self.ncamera_label.encode('utf-8')).hexdigest()
        self.cameras = dict()
        self.input_file_name = parsed_args.cam_yaml
        data = read_yaml(self.input_file_name)
        self.__extract_data(data)
        self.to_ncamera = parsed_args.to_ncamera
        # A kalibr camchain contains no IMU settings, so they have to come
        # from a separate file unless only the ncamera part is written.
        if self.__is_kalibr_format and not self.to_ncamera:
            if parsed_args.imu_yaml is None:
                sys.exit("No IMU yaml file given, exiting.")
            imu_data = read_yaml(parsed_args.imu_yaml)
            if not isinstance(imu_data, dict):
                # read_yaml() returns None for empty or unparsable files.
                sys.exit("Could not read the IMU yaml file, exiting.")
            self.imu.parse_from_kalibr_format(imu_data)
            self.__has_imu = True
        self.output_file_name = parsed_args.output_file_name
        self.imu_output_file_name = parsed_args.imu_output_file_name

    def write_converted_yaml(self):
        """Write the calibration in the format the input was not given in."""
        if self.__is_kalibr_format:
            self.__write_maplab_data()
        else:
            self.__write_kalibr_data()

    def __extract_data(self, data):
        """Detect the format of the loaded yaml data and parse it.

        A 'cam0' key marks a kalibr camchain, an 'ncameras' key a maplab
        sensors file and a 'cameras' key a single maplab ncamera. Anything
        else terminates the program.
        """
        if not isinstance(data, dict):
            # Covers empty or unparsable files (read_yaml() returned None) as
            # well as documents whose top level is not a mapping.
            sys.exit("Unknown calibration format. Exiting.")
        if 'cam0' in data:
            self.__is_kalibr_format = True
            self.__extract_kalibr_data(data)
        elif 'ncameras' in data:
            self.__is_kalibr_format = False
            self.__extract_maplab_data(data)
        elif 'cameras' in data:
            self.__is_kalibr_format = False
            # A plain ncamera file: wrap it so that it looks like the
            # 'ncameras' list of a sensors file. It carries no IMU.
            modified_data = {'ncameras': [data]}
            self.__extract_maplab_data(modified_data, extract_imu=False)
        else:
            sys.exit("Unknown calibration format. Exiting.")

    def __extract_kalibr_data(self, data):
        """Parse all cameras of a kalibr camchain.

        Every top-level entry is counted as a camera and the cameras are
        looked up as 'cam0', 'cam1', ... in that order.
        """
        self.num_cameras = len(data)
        for cam_index in range(self.num_cameras):
            cam_name = 'cam{0}'.format(cam_index)
            camera = CameraConfig(cam_name)
            camera.parse_from_kalibr_format(data[cam_name])
            self.cameras[cam_name] = camera
        print("Got a calibration file in kalibr format.")

    def __extract_maplab_data(self, data, extract_imu=True):
        """Parse the first ncamera (and optionally the IMU) of a maplab file.

        Args:
            data: Mapping with an 'ncameras' list and, if extract_imu is
                true, a 'sensors' list whose first entry is read as the IMU.
            extract_imu: Whether to read the IMU settings.
        """
        ncamera = data['ncameras'][0]
        self.ncamera_label = ncamera['label']
        self.id = ncamera['id']
        cameras_list = ncamera['cameras']
        self.num_cameras = len(cameras_list)
        for cam_index, cam_entry in enumerate(cameras_list):
            cam_name = 'cam{0}'.format(cam_index)
            camera = CameraConfig(cam_name)
            camera.parse_from_maplab_format(cam_entry)
            self.cameras[cam_name] = camera
        if extract_imu:
            imu_data = data['sensors'][0]
            self.imu.parse_from_maplab_format(imu_data)
            self.__has_imu = True
        print("Got a calibration file in maplab format.")

    def __write_kalibr_data(self):
        """Write the kalibr camchain and, if requested, the kalibr IMU file."""
        camchain_dict = OrderedDict()
        # Transformation of the previous camera in the chain; None for cam0,
        # which therefore gets no 'T_cn_cnm1' entry.
        T_cnm1_i = None
        for cam_index in range(self.num_cameras):
            cam_name = 'cam{0}'.format(cam_index)
            camera = self.cameras[cam_name]
            camchain_dict[cam_name] = camera.get_kalibr_format(T_cnm1_i)
            T_cnm1_i = camera.T_cn_i
        output_file_name = self.__get_output_file_name()
        write_yaml(output_file_name, camchain_dict)
        print("Wrote " + str(output_file_name) + ".")
        if self.imu_output_file_name is not None:
            assert self.__has_imu
            imu_dict = self.imu.get_kalibr_format()
            write_yaml(self.imu_output_file_name, imu_dict)
            print("Wrote " + str(self.imu_output_file_name) + ".")

    def __write_maplab_data(self):
        """Write a maplab sensors file or, if to_ncamera is set, only the
        ncamera part."""
        if not self.to_ncamera:
            assert self.__has_imu
        sensors_dict = OrderedDict()
        if not self.to_ncamera:
            sensors_dict['sensors'] = [self.imu.get_maplab_format()]
        ncamera_dict = OrderedDict([('label', self.ncamera_label),
                                    ('id', self.id), ('cameras', list())])
        sensors_dict['ncameras'] = [ncamera_dict]
        for cam_index in range(self.num_cameras):
            cam_name = 'cam{0}'.format(cam_index)
            ncamera_dict['cameras'].append(
                self.cameras[cam_name].get_maplab_format())
        default_flow_style = None
        use_indent_dumper = True
        output_file_name = self.__get_output_file_name()
        # The sensors format wraps the ncamera, the ncamera format is the
        # bare ncamera dictionary.
        output_dict = ncamera_dict if self.to_ncamera else sensors_dict
        write_yaml(output_file_name, output_dict, default_flow_style,
                   use_indent_dumper)
        print("Wrote " + str(output_file_name) + ".")

    def __get_output_file_name(self):
        """Return the name of the file the converted calibration goes to.

        An explicitly given output file name always wins. Otherwise the name
        is derived from the input file name by swapping the format keyword
        ('camchain-imucam', 'sensors' or 'ncamera'); if the input file name
        contains no such keyword, a default name in the current directory is
        returned.
        """
        if self.output_file_name is not None:
            return self.output_file_name
        # Only the file name itself is rewritten. Replacing the keyword in
        # the whole path would also alter matching directory names and point
        # to a different (probably nonexistent) directory.
        directory, base_name = os.path.split(self.input_file_name)
        if self.__is_kalibr_format:
            if 'camchain-imucam' in base_name:
                keyword = 'ncamera' if self.to_ncamera else 'sensors'
                return os.path.join(
                    directory, base_name.replace('camchain-imucam', keyword))
            return 'sensors.yaml'
        for keyword in ('sensors', 'ncamera'):
            if keyword in base_name:
                return os.path.join(
                    directory, base_name.replace(keyword, 'camchain-imucam'))
        return 'camchain-imucam.yaml'


if __name__ == '__main__':
    # Register the OrderedDict representer before anything is dumped.
    setup_yaml()
    # Parse the command line, load the calibration and write it in the other
    # format.
    CalibrationConfig(parse_args()).write_converted_yaml()
